## Run experiments
import argparse
import warnings
import random
import numpy as np
from typing import Dict
from tqdm import tqdm
from sklearn.exceptions import ConvergenceWarning

from src.utils import load_state_action_data, read_jsonl_file
from src.utils import compute_rel_rmse, save_rel_rmses, compute_mean_ci
from src.ope import run_ope

## Suppress warnings
# Installs two process-wide "ignore" filters at import time: one for
# scikit-learn's ConvergenceWarning and one for every FutureWarning
# (regardless of which library emits it).
# NOTE(review): presumably done to keep the tqdm/console output readable
# during long runs -- confirm that hiding all FutureWarnings is intended.
warnings.filterwarnings(action="ignore", category=ConvergenceWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)

def _print_timing_summary(results, rel_rmse) -> None:
    """Print each method's Rel-RMSE with its mean elapsed time and 95% CI.

    Every key in a result that starts with ``'time_'`` is treated as the
    elapsed time of the method named by the remainder of the key; the same
    method name is used to look up the score in ``rel_rmse``.

    Args:
        results: Per-sample result mappings as returned by ``run_ope``.
        rel_rmse: Object indexable by method name (output of
            ``compute_rel_rmse``); a method that has a timing entry but no
            score raises the lookup error of that object.
    """
    time_prefix = 'time_'

    print("------Rel-RMSE and Elapsed Times------")
    # Group the timings by key, keeping first-seen order of the methods.
    elapsed_times = {}
    for result in results:
        for key, value in result.items():
            if key.startswith(time_prefix):
                elapsed_times.setdefault(key, []).append(value)

    for key, times in elapsed_times.items():
        method = key[len(time_prefix):]
        mean_t, ci_t = compute_mean_ci(np.array(times))
        print(f"    Method: {method}")
        print(f"        Rel-RMSE: {rel_rmse[method]:.6f}, Time: {mean_t:.6f}, Time CI(95%): [{ci_t[0]:.6f},{ci_t[1]:.6f}]")
    print("---------------------------------------")


def run_experiment(job_name: str, config: Dict, random_state: int) -> None:
    """Run OPE on every dataset in ``config`` and save the Rel-RMSE scores.

    For each dataset, every row of the action data is evaluated with
    ``run_ope``; the per-row results are reduced with ``compute_rel_rmse``,
    a timing summary is printed, and finally all per-dataset scores are
    written with ``save_rel_rmses``.

    Args:
        job_name: Identifier forwarded unchanged to ``save_rel_rmses``.
        config: Experiment settings. Must contain the keys
            ``'estimate_pi_b'``, ``'datasets'``, ``'num_loggers'``,
            ``'n_fold'``, ``'exploration_probs'`` and ``'stratum_ratio'``.
        random_state: Seed forwarded to every ``run_ope`` call.

    Raises:
        KeyError: If a required key is missing from ``config``.
    """
    estimate_pi_b = config['estimate_pi_b']
    datasets = config['datasets']
    num_loggers = config['num_loggers']
    n_fold = config['n_fold']
    exploration_probs = config['exploration_probs']
    stratum_ratio = config['stratum_ratio']

    rel_rmses = {}
    for dataset in datasets:
        state_data_df, action_data_df, num_features, num_actions = load_state_action_data(dataset)

        # Walk the three columns in lockstep by position. Looking rows up
        # with ``column[i]`` is label-based and only matches the row position
        # when the frame carries a default 0..n-1 index.
        rows = zip(
            action_data_df['y_gt'],
            action_data_df['y_eval_pol'],
            action_data_df['row_indices'],
        )

        results = []
        # zip() has no length, so tqdm needs ``total`` to draw the progress bar.
        for y_gt, y_eval_pol, row_indices in tqdm(rows, total=len(action_data_df), desc=f"{dataset}"):
            # Feature columns are the first ``num_features`` columns of the
            # state table, restricted to the rows this sample refers to.
            states = np.array(state_data_df.iloc[row_indices, :num_features])

            ope_args = {
                'y_gt': y_gt,
                'y_eval_pol': y_eval_pol,
                'states': states,
                'num_actions': num_actions,
                'num_loggers': num_loggers,
                'exploration_probs': exploration_probs,
                'stratum_ratio': stratum_ratio,
                'n_fold': n_fold,
                'estimate_pi_b': estimate_pi_b,
                'random_state': random_state,
            }
            result = run_ope(ope_args)
            # Ground-truth value: mean of the element-wise agreement between
            # y_gt and y_eval_pol (assumes both are array-like -- TODO confirm).
            result['J_gt'] = (y_gt == y_eval_pol).mean()
            results.append(result)

        rel_rmse = compute_rel_rmse(results)
        rel_rmses[dataset] = rel_rmse

        ## Compute avg time & 95% CI
        _print_timing_summary(results, rel_rmse)

    save_rel_rmses(job_name, config, rel_rmses)


if __name__ == "__main__":
    # Command line: a config file name (looked up under configs/) and a base seed.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", "-c", type=str, required=True)
    cli.add_argument("--random_state", "-r", type=int, default=42)
    args = cli.parse_args()
    print(args)

    # Seed both global RNGs with the base seed before anything else runs.
    base_seed = args.random_state
    np.random.seed(base_seed)
    random.seed(base_seed)

    # The job name is the config file name with its extension removed.
    config_name = args.config
    for suffix in ('.jsonl', '.json'):
        if config_name.endswith(suffix):
            job_name = config_name[:-len(suffix)]
            break
    else:
        raise ValueError("Job config must be either .json or .jsonl file")

    configs = read_jsonl_file(f'configs/{config_name}')
    num_configs = len(configs)
    # Each config runs as its own experiment with seed = base seed + position.
    for position, config in enumerate(configs):
        print(f"Running experiment: {position+1}/{num_configs}")
        run_experiment(job_name, config, random_state=base_seed + position)